// Copyright (c) 2021 Anatoly Ikorsky
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

use std::{
    borrow::Cow,
    cmp::min,
    convert::TryFrom,
    fmt,
    io::{self, Read},
};

use byteorder::{LittleEndian, ReadBytesExt};
use saturating::Saturating as S;

use crate::{
    binlog::{
        consts::{BinlogVersion, EventType, StatusVarKey},
        BinlogCtx, BinlogEvent, BinlogStruct,
    },
    constants::{Flags2, SqlMode},
    io::ParseBuf,
    misc::{
        raw::{
            bytes::{BareU16Bytes, BareU8Bytes, EofBytes, NullBytes, U8Bytes},
            int::*,
            RawBytes, RawFlags, Skip,
        },
        unexpected_buf_eof,
    },
    proto::{MyDeserialize, MySerialize},
};

use super::BinlogEventHeader;

/// A query event is created for each query that modifies the database, unless the query
/// is logged row-based.
///
/// The event body consists of a fixed 13-byte post-header (`thread_id`, `execution_time`,
/// `schema_len`, `error_code`, `status_vars_len`) followed by the payload (`status_vars`,
/// `schema`, one skipped byte and the `query` text running to the end of the event).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct QueryEvent<'a> {
    // post-header fields
    /// The ID of the thread that issued this statement. It is needed for temporary tables.
    thread_id: RawInt<LeU32>,
    /// The time from when the query started to when it was logged in the binlog, in seconds.
    execution_time: RawInt<LeU32>,
    /// Length of `schema` in bytes.
    schema_len: RawInt<u8>,
    /// Error code generated by the master. If the master fails, the slave will fail with
    /// the same error code.
    error_code: RawInt<LeU16>,
    /// Length of `status_vars` in bytes.
    status_vars_len: RawInt<LeU16>,

    // payload
    /// Zero or more status variables (`status_vars_len` bytes).
    ///
    /// Each status variable consists of one byte identifying the variable stored, followed
    /// by the value of the variable. Please consult the MySql documentation.
    ///
    /// Only available if binlog version >= 4 (empty otherwise).
    status_vars: StatusVars<'a>,
    /// The currently selected database name (`schema_len` bytes).
    schema: RawBytes<'a, BareU8Bytes>,
    /// One byte following the schema name that is not interpreted (skipped when parsing).
    /// NOTE(review): in the MySQL format this is presumably the schema's `0x00` terminator.
    __skip: Skip<1>,
    /// The SQL query (the rest of the event).
    query: RawBytes<'a, EofBytes>,
}

impl<'a> QueryEvent<'a> {
    /// Creates a new instance.
    pub fn new(status_vars: impl Into<Cow<'a, [u8]>>, schema: impl Into<Cow<'a, [u8]>>) -> Self {
        let status_vars = StatusVars(RawBytes::new(status_vars));
        let schema = RawBytes::new(schema);
        Self {
            thread_id: Default::default(),
            execution_time: Default::default(),
            schema_len: RawInt::new(schema.len() as u8),
            error_code: Default::default(),
            status_vars_len: RawInt::new(status_vars.0.len() as u16),
            status_vars,
            schema,
            __skip: Default::default(),
            query: Default::default(),
        }
    }

    /// Sets the `thread_id` value.
    pub fn with_thread_id(mut self, thread_id: u32) -> Self {
        self.thread_id = RawInt::new(thread_id);
        self
    }

    /// Sets the `execution_time` value.
    pub fn with_execution_time(mut self, execution_time: u32) -> Self {
        self.execution_time = RawInt::new(execution_time);
        self
    }

    /// Sets the `error_code` value.
    pub fn with_error_code(mut self, error_code: u16) -> Self {
        self.error_code = RawInt::new(error_code);
        self
    }

    /// Sets the `status_vars` value (max length is `u16::MAX).
    pub fn with_status_vars(mut self, status_vars: impl Into<Cow<'a, [u8]>>) -> Self {
        self.status_vars = StatusVars(RawBytes::new(status_vars));
        self.status_vars_len.0 = self.status_vars.0.len() as u16;
        self
    }

    /// Sets the `schema` value (max length is `u8::MAX).
    pub fn with_schema(mut self, schema: impl Into<Cow<'a, [u8]>>) -> Self {
        self.schema = RawBytes::new(schema);
        self.schema_len.0 = self.schema.len() as u8;
        self
    }

    /// Sets the `query` value.
    pub fn with_query(mut self, query: impl Into<Cow<'a, [u8]>>) -> Self {
        self.query = RawBytes::new(query);
        self
    }

    /// Returns the `thread_id` value.
    ///
    /// `thread_id` is the ID of the thread that issued this statement.
    /// It is needed for temporary tables.
    pub fn thread_id(&self) -> u32 {
        self.thread_id.0
    }

    /// Returns the `execution_time` value.
    ///
    /// `execution_time` is the time from when the query started to when it was logged
    /// in the binlog, in seconds.
    pub fn execution_time(&self) -> u32 {
        self.execution_time.0
    }

    /// Returns the `error_code` value.
    ///
    /// `error_code` is the error code generated by the master. If the master fails, the slave will
    /// fail with the same error code.
    pub fn error_code(&self) -> u16 {
        self.error_code.0
    }

    /// Returns the `status_vars` value.
    ///
    /// `status_vars` contains zero or more status variables. Each status variable consists of one
    /// byte identifying the variable stored, followed by the value of the variable.
    pub fn status_vars_raw(&'a self) -> &'a [u8] {
        self.status_vars.0.as_bytes()
    }

    /// Returns an iterator over status variables.
    pub fn status_vars(&'a self) -> &'a StatusVars<'a> {
        &self.status_vars
    }

    /// Returns the `schema` value.
    ///
    /// `schema` is schema name.
    pub fn schema_raw(&'a self) -> &'a [u8] {
        self.schema.as_bytes()
    }

    /// Returns the `schema` value as a string (lossy converted).
    pub fn schema(&'a self) -> Cow<'a, str> {
        self.schema.as_str()
    }

    /// Returns the `query` value.
    ///
    /// `query` is the corresponding LOAD DATA INFILE statement.
    pub fn query_raw(&'a self) -> &'a [u8] {
        self.query.as_bytes()
    }

    /// Returns the `query` value as a string (lossy converted).
    pub fn query(&'a self) -> Cow<'a, str> {
        self.query.as_str()
    }

    pub fn into_owned(self) -> QueryEvent<'static> {
        QueryEvent {
            thread_id: self.thread_id,
            execution_time: self.execution_time,
            schema_len: self.schema_len,
            error_code: self.error_code,
            status_vars_len: self.status_vars_len,
            status_vars: self.status_vars.into_owned(),
            schema: self.schema.into_owned(),
            __skip: self.__skip,
            query: self.query.into_owned(),
        }
    }
}

impl<'de> MyDeserialize<'de> for QueryEvent<'de> {
    const SIZE: Option<usize> = None;
    type Ctx = BinlogCtx<'de>;

    fn deserialize(ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
        let mut sbuf: ParseBuf = buf.parse(13)?;
        let thread_id = sbuf.parse_unchecked(())?;
        let execution_time = sbuf.parse_unchecked(())?;
        let schema_len: RawInt<u8> = sbuf.parse_unchecked(())?;
        let error_code = sbuf.parse_unchecked(())?;
        let status_vars_len: RawInt<LeU16> = sbuf.parse_unchecked(())?;

        let post_header_len = ctx.fde.get_event_type_header_length(Self::EVENT_TYPE);
        if !buf.checked_skip(post_header_len.saturating_sub(13) as usize) {
            return Err(unexpected_buf_eof());
        }

        let status_vars = buf.parse(*status_vars_len)?;
        let schema = buf.parse(*schema_len as usize)?;
        let __skip = buf.parse(())?;
        let query = buf.parse(())?;

        Ok(Self {
            thread_id,
            execution_time,
            schema_len,
            error_code,
            status_vars_len,
            status_vars,
            schema,
            __skip,
            query,
        })
    }
}

impl MySerialize for QueryEvent<'_> {
    /// Appends the event body to `buf`: the post-header fields, then the payload, in the
    /// same order `deserialize` reads them.
    fn serialize(&self, buf: &mut Vec<u8>) {
        // Post-header.
        self.thread_id.serialize(buf);
        self.execution_time.serialize(buf);
        self.schema_len.serialize(buf);
        self.error_code.serialize(buf);
        self.status_vars_len.serialize(buf);

        // Payload.
        self.status_vars.serialize(buf);
        self.schema.serialize(buf);
        self.__skip.serialize(buf);
        self.query.serialize(buf);
    }
}

impl<'a> BinlogEvent<'a> for QueryEvent<'a> {
    const EVENT_TYPE: EventType = EventType::QUERY_EVENT;
}

impl<'a> BinlogStruct<'a> for QueryEvent<'a> {
    /// Returns the serialized length of the event body (the part following
    /// `BinlogEventHeader`).
    ///
    /// `status_vars` and `schema` are counted up to the maximum their length fields can
    /// express. The total is clamped so that header plus body still fits in a `u32`.
    fn len(&self, _version: BinlogVersion) -> usize {
        // thread_id + execution_time + schema_len + error_code + status_vars_len
        const POST_HEADER_LEN: usize = 4 + 4 + 1 + 2 + 2;
        // The single byte written after the schema name.
        const SCHEMA_SUFFIX_LEN: usize = 1;

        let status_vars_len = min(self.status_vars.0.len(), u16::MAX as usize);
        let schema_len = min(self.schema.0.len(), u8::MAX as usize);

        let mut total = S(POST_HEADER_LEN);
        total += S(status_vars_len);
        total += S(schema_len);
        total += S(SCHEMA_SUFFIX_LEN);
        total += S(self.query.0.len());

        min(total.0, u32::MAX as usize - BinlogEventHeader::LEN)
    }
}

/// Status variable value.
///
/// Parsed form of a single `StatusVar`, as returned by `StatusVar::get_value`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum StatusVarVal<'a> {
    /// `Flags2` flags (little-endian `u32`).
    Flags2(RawFlags<Flags2, LeU32>),
    /// `SqlMode` flags (little-endian `u64`).
    SqlMode(RawFlags<SqlMode, LeU64>),
    /// Ignored by this implementation (raw, uninterpreted bytes).
    Catalog(&'a [u8]),
    /// Auto-increment settings.
    AutoIncrement {
        increment: u16,
        offset: u16,
    },
    /// Character set / collation identifiers.
    Charset {
        charset_client: u16,
        collation_connection: u16,
        collation_server: u16,
    },
    /// Will be empty if timezone length is `0`.
    TimeZone(RawBytes<'a, U8Bytes>),
    /// Will be empty if catalog length is `0`.
    CatalogNz(RawBytes<'a, U8Bytes>),
    /// `LcTimeNames` value.
    LcTimeNames(u16),
    /// `CharsetDatabase` value.
    CharsetDatabase(u16),
    /// `TableMapForUpdate` value.
    TableMapForUpdate(u64),
    /// Raw 4-byte `MasterDataWritten` value.
    MasterDataWritten([u8; 4]),
    /// User name and host name of the invoker.
    Invoker {
        username: RawBytes<'a, U8Bytes>,
        hostname: RawBytes<'a, U8Bytes>,
    },
    /// Database names carried by the `UpdatedDbNames` variable (NUL-terminated on the wire).
    UpdatedDbNames(Vec<RawBytes<'a, NullBytes>>),
    /// `Microseconds` value (stored as 3 bytes on the wire).
    Microseconds(u32),
    /// Ignored.
    CommitTs(&'a [u8]),
    /// Ignored.
    CommitTs2(&'a [u8]),
    /// `0` is interpreted as `false` and everything else as `true`.
    ExplicitDefaultsForTimestamp(bool),
    /// `DdlLoggedWithXid` value.
    DdlLoggedWithXid(u64),
    /// `DefaultCollationForUtf8mb4` value.
    DefaultCollationForUtf8mb4(u16),
    /// `SqlRequirePrimaryKey` value.
    SqlRequirePrimaryKey(u8),
    /// `DefaultTableEncryption` value.
    DefaultTableEncryption(u8),
}

/// Raw status variable.
///
/// Holds the key and the still-unparsed value bytes as produced by `StatusVarsIterator`;
/// call `get_value` to parse the value.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVar<'a> {
    /// Status variable key.
    key: StatusVarKey,
    /// Raw value of a status variable (without the key byte). Use `Self::get_value`.
    value: &'a [u8],
}

impl StatusVar<'_> {
    /// Returns parsed value of this status variable, or raw value in case of error.
    ///
    /// # Errors
    ///
    /// Returns the raw value bytes if they are too short or otherwise malformed for the
    /// value type implied by the key.
    pub fn get_value(&self) -> Result<StatusVarVal, &[u8]> {
        // An `UpdatedDbNames` count above this limit means the server wrote only the count
        // (an "over the limit" marker) and no database names follow.
        const MAX_DBS_IN_EVENT_MTS: usize = 16;

        match self.key {
            StatusVarKey::Flags2 => {
                let mut read = self.value;
                read.read_u32::<LittleEndian>()
                    .map(RawFlags::new)
                    .map(StatusVarVal::Flags2)
                    .map_err(|_| self.value)
            }
            StatusVarKey::SqlMode => {
                let mut read = self.value;
                read.read_u64::<LittleEndian>()
                    .map(RawFlags::new)
                    .map(StatusVarVal::SqlMode)
                    .map_err(|_| self.value)
            }
            StatusVarKey::Catalog => Ok(StatusVarVal::Catalog(self.value)),
            StatusVarKey::AutoIncrement => {
                let mut read = self.value;
                let increment = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let offset = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::AutoIncrement { increment, offset })
            }
            StatusVarKey::Charset => {
                let mut read = self.value;
                let charset_client = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let collation_connection =
                    read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let collation_server = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::Charset {
                    charset_client,
                    collation_connection,
                    collation_server,
                })
            }
            StatusVarKey::TimeZone => {
                let mut read = self.value;
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let text = read.get(..len).ok_or(self.value)?;
                Ok(StatusVarVal::TimeZone(RawBytes::new(text)))
            }
            StatusVarKey::CatalogNz => {
                let mut read = self.value;
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let text = read.get(..len).ok_or(self.value)?;
                Ok(StatusVarVal::CatalogNz(RawBytes::new(text)))
            }
            StatusVarKey::LcTimeNames => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::LcTimeNames(val))
            }
            StatusVarKey::CharsetDatabase => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::CharsetDatabase(val))
            }
            StatusVarKey::TableMapForUpdate => {
                let mut read = self.value;
                let val = read.read_u64::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::TableMapForUpdate(val))
            }
            StatusVarKey::MasterDataWritten => {
                let mut read = self.value;
                let mut val = [0u8; 4];
                read.read_exact(&mut val).map_err(|_| self.value)?;
                Ok(StatusVarVal::MasterDataWritten(val))
            }
            StatusVarKey::Invoker => {
                let mut read = self.value;

                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let username = read.get(..len).ok_or(self.value)?;
                // Safe to slice: `get(..len)` above succeeded.
                read = &read[len..];

                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let hostname = read.get(..len).ok_or(self.value)?;

                Ok(StatusVarVal::Invoker {
                    username: RawBytes::new(username),
                    hostname: RawBytes::new(hostname),
                })
            }
            StatusVarKey::UpdatedDbNames => {
                let mut read = self.value;
                let count = read.read_u8().map_err(|_| self.value)? as usize;

                if count > MAX_DBS_IN_EVENT_MTS {
                    // Only the overflow marker was written; there are no names to return.
                    return Ok(StatusVarVal::UpdatedDbNames(Vec::new()));
                }

                let mut names = Vec::with_capacity(count);
                for _ in 0..count {
                    let index = read.iter().position(|x| *x == 0).ok_or(self.value)?;
                    names.push(RawBytes::new(&read[..index]));
                    // Skip the name *and* its NUL terminator; otherwise the next search
                    // would stop at this terminator and yield an empty name.
                    read = &read[index + 1..];
                }

                Ok(StatusVarVal::UpdatedDbNames(names))
            }
            StatusVarKey::Microseconds => {
                let mut read = self.value;
                // The value is 3 bytes long (see `StatusVarsIterator`), so read a `u24`.
                let val = read.read_u24::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::Microseconds(val))
            }
            StatusVarKey::CommitTs => Ok(StatusVarVal::CommitTs(self.value)),
            StatusVarKey::CommitTs2 => Ok(StatusVarVal::CommitTs2(self.value)),
            StatusVarKey::ExplicitDefaultsForTimestamp => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::ExplicitDefaultsForTimestamp(val != 0))
            }
            StatusVarKey::DdlLoggedWithXid => {
                let mut read = self.value;
                let val = read.read_u64::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::DdlLoggedWithXid(val))
            }
            StatusVarKey::DefaultCollationForUtf8mb4 => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::DefaultCollationForUtf8mb4(val))
            }
            StatusVarKey::SqlRequirePrimaryKey => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::SqlRequirePrimaryKey(val))
            }
            StatusVarKey::DefaultTableEncryption => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::DefaultTableEncryption(val))
            }
        }
    }
}

impl fmt::Debug for StatusVar<'_> {
    // Shows the key together with the parsed value (or the raw bytes if parsing fails).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let parsed = self.get_value();
        let mut out = f.debug_struct("StatusVar");
        out.field("key", &self.key);
        out.field("value", &parsed);
        out.finish()
    }
}

/// Status variables of a QueryEvent.
///
/// Wraps the raw status-variables bytes; individual variables are parsed lazily via
/// `StatusVars::iter` / `StatusVars::get_status_var`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVars<'a>(pub RawBytes<'a, BareU16Bytes>);

impl<'a> StatusVars<'a> {
    /// Returns an iterator over QueryEvent status variables.
    pub fn iter(&'a self) -> StatusVarsIterator<'a> {
        StatusVarsIterator::new(self.0.as_bytes())
    }

    /// Returns raw value of a status variable by key.
    pub fn get_status_var(&'a self, needle: StatusVarKey) -> Option<StatusVar<'a>> {
        self.iter()
            .find_map(|var| if var.key == needle { Some(var) } else { None })
    }

    pub fn into_owned(self) -> StatusVars<'static> {
        StatusVars(self.0.into_owned())
    }
}

impl<'de> MyDeserialize<'de> for StatusVars<'de> {
    const SIZE: Option<usize> = None;
    type Ctx = u16;

    fn deserialize(len: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
        Ok(Self(buf.parse(len as usize)?))
    }
}

impl MySerialize for StatusVars<'_> {
    fn serialize(&self, buf: &mut Vec<u8>) {
        self.0.serialize(buf);
    }
}

impl fmt::Debug for StatusVars<'_> {
    // Formats as the list of parsed status variables.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.iter(), f)
    }
}

/// Iterator over status vars of a `QueryEvent`.
///
/// It will stop iteration if vars can't be parsed (unknown key or truncated value).
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVarsIterator<'a> {
    /// Offset in `status_vars` of the next key byte to read.
    pos: usize,
    /// Raw status-variables bytes being iterated.
    status_vars: &'a [u8],
}

impl<'a> StatusVarsIterator<'a> {
    /// Creates new instance.
    pub fn new(status_vars: &'a [u8]) -> StatusVarsIterator<'a> {
        Self {
            pos: 0,
            status_vars,
        }
    }
}

impl fmt::Debug for StatusVarsIterator<'_> {
    // Lists the remaining status variables. A clone is iterated so that formatting
    // does not advance `self`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let remaining = self.clone();
        let mut list = f.debug_list();
        list.entries(remaining);
        list.finish()
    }
}

impl<'a> Iterator for StatusVarsIterator<'a> {
    type Item = StatusVar<'a>;

    /// Yields the next status variable.
    ///
    /// Returns `None` at the end of the buffer, on an unknown key, or when a value is
    /// truncated.
    fn next(&mut self) -> Option<Self::Item> {
        // An `UpdatedDbNames` count above this limit is an "over the limit" marker: the
        // server writes only the count byte and no database names.
        const MAX_DBS_IN_EVENT_MTS: usize = 16;

        let key = *self.status_vars.get(self.pos)?;
        let key = StatusVarKey::try_from(key).ok()?;
        self.pos += 1;

        // Consumes the next `$len` bytes and returns them (`None` if they are not all present).
        macro_rules! get_fixed {
            ($len:expr) => {{
                self.pos += $len;
                self.status_vars.get((self.pos - $len)..self.pos)?
            }};
        }

        // Consumes a one-byte length prefix, that many bytes and `$suffix_len` trailing bytes.
        macro_rules! get_var {
            ($suffix_len:expr) => {{
                let len = *self.status_vars.get(self.pos)? as usize;
                get_fixed!(1 + len + $suffix_len)
            }};
        }

        let value = match key {
            StatusVarKey::Flags2 => get_fixed!(4),
            StatusVarKey::SqlMode => get_fixed!(8),
            StatusVarKey::Catalog => get_var!(1),
            StatusVarKey::AutoIncrement => get_fixed!(4),
            StatusVarKey::Charset => get_fixed!(6),
            StatusVarKey::TimeZone => get_var!(0),
            StatusVarKey::CatalogNz => get_var!(0),
            StatusVarKey::LcTimeNames => get_fixed!(2),
            StatusVarKey::CharsetDatabase => get_fixed!(2),
            StatusVarKey::TableMapForUpdate => get_fixed!(8),
            StatusVarKey::MasterDataWritten => get_fixed!(4),
            StatusVarKey::Invoker => {
                let user_len = *self.status_vars.get(self.pos)? as usize;
                let host_len = *self.status_vars.get(self.pos + 1 + user_len)? as usize;
                get_fixed!(1 + user_len + 1 + host_len)
            }
            StatusVarKey::UpdatedDbNames => {
                let count = *self.status_vars.get(self.pos)? as usize;
                // The count byte itself.
                let mut total = 1;
                // Over the limit, no names follow the count byte.
                if count <= MAX_DBS_IN_EVENT_MTS {
                    for _ in 0..count {
                        while *self.status_vars.get(self.pos + total)? != 0x00 {
                            total += 1;
                        }
                        // The NUL terminator of this name.
                        total += 1;
                    }
                }
                get_fixed!(total)
            }
            StatusVarKey::Microseconds => get_fixed!(3),
            StatusVarKey::CommitTs => get_fixed!(0),
            StatusVarKey::CommitTs2 => get_fixed!(0),
            StatusVarKey::ExplicitDefaultsForTimestamp => get_fixed!(1),
            StatusVarKey::DdlLoggedWithXid => get_fixed!(8),
            StatusVarKey::DefaultCollationForUtf8mb4 => get_fixed!(2),
            StatusVarKey::SqlRequirePrimaryKey => get_fixed!(1),
            StatusVarKey::DefaultTableEncryption => get_fixed!(1),
        };

        Some(StatusVar { key, value })
    }
}
